package com.auditlog.sql.util;

import com.auditlog.util.ExpressionUtils;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.expression.operators.arithmetic.Addition;
import net.sf.jsqlparser.expression.operators.arithmetic.Division;
import net.sf.jsqlparser.expression.operators.arithmetic.Multiplication;
import net.sf.jsqlparser.expression.operators.arithmetic.Subtraction;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.MultiExpressionList;

import java.sql.SQLException;
import java.util.HashSet;
import java.util.Set;

/**
 * Helpers for deciding whether a parsed SQL {@link Expression} is a constant
 * (a literal, a JDBC parameter, or an arithmetic expression that can be
 * evaluated) and for extracting the value of such an expression.
 */
@Slf4j
public class ConstantsExpressionUtil {
    /** Expression classes that are treated as constants as-is (matched by exact class). */
    private static final Set<Class<? extends Expression>> CONSTANT_EXPRESSIONS = new HashSet<>(8);
    /** Arithmetic expression classes that are candidates for evaluation (matched by exact class). */
    private static final Set<Class<? extends Expression>> CALCULATE_EXPRESSIONS = new HashSet<>(8);

    static {
        CONSTANT_EXPRESSIONS.add(LongValue.class);
        CONSTANT_EXPRESSIONS.add(StringValue.class);
        CONSTANT_EXPRESSIONS.add(DoubleValue.class);
        // A JDBC parameter is bound directly to a value, so it is counted as a constant.
        CONSTANT_EXPRESSIONS.add(JdbcParameter.class);

        CALCULATE_EXPRESSIONS.add(Division.class);
        CALCULATE_EXPRESSIONS.add(Multiplication.class);
        CALCULATE_EXPRESSIONS.add(Addition.class);
        CALCULATE_EXPRESSIONS.add(Subtraction.class);
    }

    /**
     * Tells whether the expression is a constant: either one of the registered
     * constant classes or an arithmetic expression that can be evaluated.
     *
     * @param expression the expression to inspect; {@code null} is not a constant
     * @return {@code true} if the expression is a constant
     */
    public static boolean isConstant(Expression expression) {
        if (expression == null) {
            return false;
        }
        return CONSTANT_EXPRESSIONS.contains(expression.getClass()) || canCalculate(expression);
    }

    /**
     * Tells whether every expression in the list is a constant
     * (see {@link #isConstant(Expression)}).
     *
     * @param expressionList the list whose expressions are checked
     * @return {@code true} if all contained expressions are constants
     */
    public static boolean isAllConstant(ExpressionList expressionList) {
        return expressionList.getExpressions().stream().allMatch(ConstantsExpressionUtil::isConstant);
    }

    /**
     * Tells whether every expression list in the multi-list consists only of constants.
     *
     * @param multiExpressionList the multi-list whose expression lists are checked
     * @return {@code true} if all contained expression lists are all-constant
     */
    public static boolean isAllConstant(MultiExpressionList multiExpressionList) {
        return multiExpressionList.getExpressionLists().stream()
                .allMatch(expressionList -> isAllConstant(expressionList));
    }

    /**
     * Tells whether the expression is an arithmetic expression (division,
     * multiplication, addition or subtraction) that can actually be evaluated.
     *
     * @param expression the expression to inspect; {@code null} cannot be calculated
     * @return {@code true} if the expression is arithmetic and evaluation succeeds
     */
    public static boolean canCalculate(Expression expression) {
        if (!isArithmetic(expression)) {
            return false;
        }
        try {
            ExpressionUtils.getValue(expression.toString());
            return true;
        } catch (Exception e) {
            // Expected for expressions such as "a/2" or "?/22" that contain
            // values unknown at parse time, so this is not reported as an error.
            log.debug("计算表达式：{}，出错", expression, e);
            return false;
        }
    }

    /**
     * Returns the value of a literal or of an evaluable arithmetic expression.
     * <p>
     * Note that a {@link JdbcParameter} counts as a constant for
     * {@link #isConstant(Expression)} but has no value available here, so it
     * results in an {@link SQLException}.
     *
     * @param expression the expression to resolve
     * @return the literal value, or the result of evaluating the arithmetic expression
     * @throws SQLException if the expression is not a supported literal, or is an
     *                      arithmetic expression that cannot be evaluated (the
     *                      evaluation failure is attached as the cause)
     */
    public static Object getValue(Expression expression) throws SQLException {
        if (expression instanceof LongValue) {
            return ((LongValue) expression).getValue();
        }
        if (expression instanceof DoubleValue) {
            return ((DoubleValue) expression).getValue();
        }
        if (expression instanceof StringValue) {
            return ((StringValue) expression).getValue();
        }
        if (isArithmetic(expression)) {
            // Evaluate exactly once instead of probing with canCalculate() first.
            try {
                return ExpressionUtils.getValue(expression.toString());
            } catch (Exception e) {
                throw new SQLException("不能获取" + expression + "的值", e);
            }
        }
        throw new SQLException("不能获取" + expression + "的值");
    }

    /** Returns whether the expression is non-null and one of the registered arithmetic classes. */
    private static boolean isArithmetic(Expression expression) {
        return expression != null && CALCULATE_EXPRESSIONS.contains(expression.getClass());
    }
}
